{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a8e986cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Added '/home/mazumdera/maxtext' to sys.path\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "from MaxText.globals import MAXTEXT_REPO_ROOT\n",
    "\n",
    "# Put the repository root at the front of the import search path, but only once,\n",
    "# so that re-running this cell does not add duplicate entries.\n",
    "repo_root_on_path = MAXTEXT_REPO_ROOT in sys.path\n",
    "if not repo_root_on_path:\n",
    "  sys.path[:0] = [MAXTEXT_REPO_ROOT]\n",
    "  print(f\"Added '{MAXTEXT_REPO_ROOT}' to sys.path\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0ab2e1dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-06-18 21:34:12.489183: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "E0000 00:00:1750282452.508183 1726814 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "E0000 00:00:1750282452.513660 1726814 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "W0000 00:00:1750282452.528073 1726814 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1750282452.528091 1726814 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1750282452.528093 1726814 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1750282452.528094 1726814 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n"
     ]
    }
   ],
   "source": [
    "import MaxText as mt\n",
    "from MaxText import pyconfig\n",
    "from MaxText import maxtext_utils\n",
    "import numpy as np\n",
    "from MaxText.input_pipeline import _input_pipeline_utils\n",
    "import os\n",
    "from MaxText import max_logging\n",
    "from MaxText import common_types\n",
    "import jax\n",
    "from MaxText import inference_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2d2de93",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Updating keys from env and command line: ['run_name', 'enable_checkpointing', 'base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_num_decoder_layers', 'per_device_batch_size', 'max_target_length', 'max_prefill_predict_length']\n",
      "Running Model: default\n",
      "Updating keys from model: []\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:2025-06-18 21:34:16,611:jax._src.xla_bridge:913: A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.\n",
      "WARNING:jax._src.xla_bridge:A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Not using emergency checkpoint, ignoring local_checkpoint_directory, local_checkpoint_period, use_replicator_service and replicator_backup_interval_minutes\n",
      "dataset_type set to tfds, will use keys['dataset_path']='' and keys['dataset_name']='c4/en:3.0.1'\n",
      "Config param activations_in_float32: False\n",
      "Config param adam_b1: 0.9\n",
      "Config param adam_b2: 0.95\n",
      "Config param adam_eps: 1e-08\n",
      "Config param adam_eps_root: 0.0\n",
      "Config param adam_weight_decay: 0.1\n",
      "Config param add_bos: True\n",
      "Config param add_eos: True\n",
      "Config param allow_split_physical_axes: False\n",
      "Config param ar_cache_axis_order: 1,2,0,3\n",
      "Config param async_checkpointing: True\n",
      "Config param attention: autoselected\n",
      "Config param attention_type: global\n",
      "Config param attn_logits_soft_cap: None\n",
      "Config param autoregressive_decode_assert: \n",
      "Config param base_emb_dim: 256\n",
      "Config param base_mlp_dim: 7168\n",
      "Config param base_moe_mlp_dim: 7168\n",
      "Config param base_num_decoder_layers: 2\n",
      "Config param base_num_kv_heads: 2\n",
      "Config param base_num_query_heads: 2\n",
      "Config param base_output_directory: \n",
      "Config param beta_fast: 32\n",
      "Config param beta_slow: 1\n",
      "Config param capacity_factor: -1.0\n",
      "Config param cast_logits_to_fp32: True\n",
      "Config param checkpoint_dir: test/checkpoints/\n",
      "Config param checkpoint_is_quantized: False\n",
      "Config param checkpoint_period: 10000\n",
      "Config param checkpoint_storage_concurrent_gb: 96\n",
      "Config param checkpoint_storage_target_data_file_size_bytes: 2147483648\n",
      "Config param checkpoint_storage_use_ocdbt: True\n",
      "Config param checkpoint_storage_use_zarr3: True\n",
      "Config param chunk_attn_window_size: 0\n",
      "Config param collect_stack_trace: False\n",
      "Config param colocated_python_data_input: False\n",
      "Config param compile_topology: \n",
      "Config param compile_topology_num_slices: -1\n",
      "Config param compiled_trainstep_file: \n",
      "Config param compute_axis_order: 0,1,2,3\n",
      "Config param context: remat\n",
      "Config param context_parallel_load_balance: True\n",
      "Config param cosine_learning_rate_final_fraction: 0.1\n",
      "Config param custom_mesh: \n",
      "Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'),)\n",
      "Config param data_shuffle_seed: 0\n",
      "Config param dataset_name: c4/en:3.0.1\n",
      "Config param dataset_path: \n",
      "Config param dataset_type: tfds\n",
      "Config param dcn_autoregressive_parallelism: 1\n",
      "Config param dcn_context_autoregressive_parallelism: 1\n",
      "Config param dcn_context_parallelism: 1\n",
      "Config param dcn_data_parallelism: -1\n",
      "Config param dcn_expert_parallelism: 1\n",
      "Config param dcn_fsdp_parallelism: 1\n",
      "Config param dcn_fsdp_transpose_parallelism: 1\n",
      "Config param dcn_parallelism: [-1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "Config param dcn_pipeline_parallelism: 1\n",
      "Config param dcn_sequence_parallelism: 1\n",
      "Config param dcn_tensor_parallelism: 1\n",
      "Config param dcn_tensor_sequence_parallelism: 1\n",
      "Config param dcn_tensor_transpose_parallelism: 1\n",
      "Config param decode_sampling_nucleus_p: -1\n",
      "Config param decode_sampling_strategy: greedy\n",
      "Config param decode_sampling_temperature: 1.0\n",
      "Config param decode_sampling_top_k: 0\n",
      "Config param decoder_block: DecoderBlockType.LLAMA2\n",
      "Config param decoder_layer_input: device\n",
      "Config param dpo_beta: 0.1\n",
      "Config param dpo_label_smoothing: 0.0\n",
      "Config param dropout_rate: 0.0\n",
      "Config param dtype: bfloat16\n",
      "Config param dtype_mm: float32\n",
      "Config param dump_hlo: False\n",
      "Config param dump_hlo_delete_local_after: True\n",
      "Config param dump_hlo_gcs_dir: \n",
      "Config param dump_hlo_local_dir: /tmp/xla_dump/\n",
      "Config param dump_hlo_module_name: jit_train_step\n",
      "Config param dump_hlo_upload_all: False\n",
      "Config param dump_hlo_xla_flags: \n",
      "Config param dump_step: -1\n",
      "Config param emb_dim: 256\n",
      "Config param enable_checkpoint_cloud_logger: False\n",
      "Config param enable_checkpointing: False\n",
      "Config param enable_data_shuffling: True\n",
      "Config param enable_dropout: True\n",
      "Config param enable_emergency_checkpoint: False\n",
      "Config param enable_gcp_goodput_metrics: True\n",
      "Config param enable_gcp_step_deviation_metrics: True\n",
      "Config param enable_goodput_recording: False\n",
      "Config param enable_jax_profiler: False\n",
      "Config param enable_llm_inference_pool: False\n",
      "Config param enable_model_warmup: False\n",
      "Config param enable_padding_causal_mask: True\n",
      "Config param enable_pathways_goodput: False\n",
      "Config param enable_prefix_caching: False\n",
      "Config param enable_single_controller: False\n",
      "Config param enable_single_replica_ckpt_restoring: False\n",
      "Config param enable_tensorboard: True\n",
      "Config param eval_data_columns: ['text']\n",
      "Config param eval_dataset_name: c4/en:3.0.1\n",
      "Config param eval_interval: -1\n",
      "Config param eval_per_device_batch_size: 1.0\n",
      "Config param eval_split: validation\n",
      "Config param eval_steps: -1\n",
      "Config param expansion_factor_real_data: -1\n",
      "Config param final_logits_soft_cap: None\n",
      "Config param first_num_dense_layers: 0\n",
      "Config param float32_logits: False\n",
      "Config param float32_qk_product: False\n",
      "Config param force_unroll: False\n",
      "Config param freeze_vision_encoder_params: True\n",
      "Config param fused_mlp: False\n",
      "Config param fused_qkv: False\n",
      "Config param gcs_metrics: False\n",
      "Config param generate_slice: v5e-16\n",
      "Config param global_batch_size_to_eval_on: 1\n",
      "Config param global_batch_size_to_load: 1\n",
      "Config param global_batch_size_to_load_eval: 1\n",
      "Config param global_batch_size_to_train_on: 1\n",
      "Config param global_parameter_scale: 1\n",
      "Config param goodput_upload_interval_seconds: 30\n",
      "Config param gradient_accumulation_steps: 1\n",
      "Config param gradient_clipping_threshold: 1.0\n",
      "Config param grain_eval_files: \n",
      "Config param grain_file_type: arrayrecord\n",
      "Config param grain_train_files: \n",
      "Config param grain_worker_count: 1\n",
      "Config param grain_worker_count_eval: 1\n",
      "Config param hardware: tpu\n",
      "Config param head_dim: 128\n",
      "Config param heartbeat_reporting_interval_in_seconds: 5\n",
      "Config param hf_data_dir: \n",
      "Config param hf_eval_files: \n",
      "Config param hf_eval_split: \n",
      "Config param hf_path: \n",
      "Config param hf_train_files: \n",
      "Config param hidden_size_for_vit: 1408\n",
      "Config param ici_autoregressive_parallelism: 1\n",
      "Config param ici_context_autoregressive_parallelism: 1\n",
      "Config param ici_context_parallelism: 1\n",
      "Config param ici_data_parallelism: 1\n",
      "Config param ici_expert_parallelism: 1\n",
      "Config param ici_fsdp_parallelism: -1\n",
      "Config param ici_fsdp_transpose_parallelism: 1\n",
      "Config param ici_parallelism: [1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "Config param ici_pipeline_parallelism: 1\n",
      "Config param ici_sequence_parallelism: 1\n",
      "Config param ici_tensor_parallelism: 1\n",
      "Config param ici_tensor_sequence_parallelism: 1\n",
      "Config param ici_tensor_transpose_parallelism: 1\n",
      "Config param image_path: \n",
      "Config param image_size_for_vit: 896\n",
      "Config param inference_benchmark_test: False\n",
      "Config param inference_metadata_file: \n",
      "Config param inference_microbenchmark_log_file_path: \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Config param inference_microbenchmark_loop_iters: 10\n",
      "Config param inference_microbenchmark_num_samples: [1, 2, 3, 4, 5]\n",
      "Config param inference_microbenchmark_prefill_lengths: 64,128,256,512,1024\n",
      "Config param inference_microbenchmark_stages: prefill,generate\n",
      "Config param inference_server: MaxtextInterleavedServer\n",
      "Config param inhomogeneous_layer_cycle_interval: 1\n",
      "Config param init_weights_seed: 0\n",
      "Config param input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']\n",
      "Config param interleave_moe_layer_step: 1\n",
      "Config param intermediate_size_for_vit: 5632\n",
      "Config param jax_cache_dir: ~/jax_cache\n",
      "Config param jax_debug_log_modules: \n",
      "Config param jax_distributed_initialization_timeout: 300\n",
      "Config param jax_profiler_port: 9999\n",
      "Config param key_proj: remat\n",
      "Config param kv_lora_rank: 512\n",
      "Config param kv_quant_axis: heads_and_dkv\n",
      "Config param kv_quant_dtype: int8\n",
      "Config param learning_rate: 3e-05\n",
      "Config param learning_rate_schedule_steps: 150001\n",
      "Config param load_balance_loss_weight: 0.01\n",
      "Config param load_from_prefill_dir: False\n",
      "Config param load_full_state_path: \n",
      "Config param load_parameters_path: \n",
      "Config param local_checkpoint_directory: \n",
      "Config param local_checkpoint_period: 0\n",
      "Config param local_rope_max_timescale: -1\n",
      "Config param log_config: True\n",
      "Config param log_period: 100\n",
      "Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_batch_no_exp', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('data', 'stage', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_heads', ('tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive')), ('activation_kv_heads', ('tensor', 'tensor_transpose', 'sequence', 'tensor_sequence')), ('activation_length', ('sequence', 'context')), ('activation_length', ('context',)), ('activation_norm_length', ('tensor_sequence', 'context', 'sequence')), ('activation_q_length', ('context',)), ('activation_kv_length', ()), ('activation_embed', ('tensor', 'tensor_transpose')), ('activation_mlp', ('tensor', 'tensor_transpose', 'tensor_sequence')), ('activation_kv', ('tensor', 'tensor_transpose', 'tensor_sequence')), ('activation_prefill_kv_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_kv_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_kv_head_dim', ('tensor', 'tensor_transpose', 'tensor_sequence')), ('activation_vocab', ('tensor', 'tensor_transpose', 'sequence', 'tensor_sequence')), ('activation_vocab', ('tensor', 'tensor_transpose')), ('activation_vocab', 'tensor_sequence'), ('activation_vocab', ('sequence', 'context')), ('activation_stage', 'stage'), ('activation_exp', ('expert',)), ('decode_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('decode_length', ('sequence',)), ('mlp', ('fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive')), ('vocab', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')), ('heads', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')), ('q_heads', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')), ('kv_heads', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert')), 
('embed', ('fsdp', 'sequence', 'tensor_transpose', 'context', 'expert')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert')), ('embed', ('fsdp', 'sequence', 'context', 'expert')), ('embed_no_exp', ('fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context')), ('embed_no_exp', ('fsdp', 'sequence', 'tensor_transpose', 'context')), ('embed_no_exp', ('fsdp', 'fsdp_transpose', 'sequence', 'context')), ('embed_no_exp', ('fsdp', 'sequence', 'context')), ('q_lora', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert')), ('q_lora', ('fsdp', 'sequence', 'context', 'tensor_transpose', 'expert')), ('q_lora', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert')), ('q_lora', ('fsdp', 'sequence', 'context', 'expert')), ('kv_lora', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert')), ('kv_lora', ('fsdp', 'sequence', 'context', 'tensor_transpose', 'expert')), ('kv_lora', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert')), ('kv_lora', ('fsdp', 'sequence', 'context', 'expert')), ('norm', ('tensor', 'tensor_transpose', 'tensor_sequence')), ('layers', 'stage'), ('kv', ()), ('kv_head_dim', ()), ('cache_batch_prefill', ()), ('cache_batch', ()), ('cache_heads_none', ()), ('cache_heads', ('autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence')), ('cache_heads', ('autoregressive', 'tensor', 'tensor_sequence')), ('cache_kv', ()), ('cache_sequence', ()), ('exp', 'expert'), ('paged_kv_heads', ('tensor',)), ('num_pages', ()), ('tokens_per_page', ()), ('paged_kv_head_dim_size', ()))\n",
      "Config param logits_dot_in_fp32: False\n",
      "Config param logits_via_embedding: False\n",
      "Config param lora_input_adapters_path: \n",
      "Config param matmul_precision: default\n",
      "Config param max_checkify: False\n",
      "Config param max_corpus_chars: 10000000\n",
      "Config param max_position_embeddings: 163840\n",
      "Config param max_prefill_predict_length: 4\n",
      "Config param max_target_length: 4\n",
      "Config param megablox: True\n",
      "Config param mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']\n",
      "Config param metrics_dir: test/metrics/\n",
      "Config param metrics_file: \n",
      "Config param micro_batch_size_to_eval_on: 1\n",
      "Config param micro_batch_size_to_train_on: 1\n",
      "Config param mla_naive_kvcache: True\n",
      "Config param mlp_activations: ['silu', 'linear']\n",
      "Config param mlp_dim: 7168\n",
      "Config param mlpwi: remat\n",
      "Config param mlpwi_0: remat\n",
      "Config param mlpwi_1: remat\n",
      "Config param mlpwo: remat\n",
      "Config param model_call_mode: \n",
      "Config param model_name: default\n",
      "Config param moe_mlp_dim: 7168\n",
      "Config param monitor_goodput: False\n",
      "Config param monitor_step_time_deviation: True\n",
      "Config param mscale: 1.0\n",
      "Config param mu_dtype: float32\n",
      "Config param multi_sampling: False\n",
      "Config param n_routing_groups: -1\n",
      "Config param nope_layer_interval: -1\n",
      "Config param normalization_layer_epsilon: 1e-05\n",
      "Config param normalize_embedding_logits: True\n",
      "Config param num_attention_heads_for_vit: 16\n",
      "Config param num_channels_for_vit: 3\n",
      "Config param num_decoder_layers: 2\n",
      "Config param num_epoch: 1\n",
      "Config param num_experts: 1\n",
      "Config param num_experts_per_tok: 1\n",
      "Config param num_hidden_layers_for_vit: 34\n",
      "Config param num_kv_heads: 2\n",
      "Config param num_layers_per_pipeline_stage: 1\n",
      "Config param num_pipeline_microbatches: -1\n",
      "Config param num_pipeline_repeats: -1\n",
      "Config param num_query_heads: 2\n",
      "Config param num_slices: 1\n",
      "Config param opt_type: adamw\n",
      "Config param optimize_mesh_for_tpu_v6e: False\n",
      "Config param optimizer_memory_host_offload: False\n",
      "Config param original_max_position_embeddings: 4096\n",
      "Config param out_proj: remat\n",
      "Config param override_model_config: False\n",
      "Config param packing: True\n",
      "Config param pagedattn_max_pages_per_group: 1\n",
      "Config param pagedattn_num_pages: 64\n",
      "Config param pagedattn_pages_per_compute_block: 4\n",
      "Config param pagedattn_tokens_per_page: 32\n",
      "Config param param_scan_axis: 1\n",
      "Config param parameter_memory_host_offload: False\n",
      "Config param patch_size_for_vit: 14\n",
      "Config param per_device_batch_size: 1.0\n",
      "Config param pipeline_delay_activation_forwarding: False\n",
      "Config param pipeline_fsdp_ag_once: False\n",
      "Config param pipeline_parallel_layers: -1\n",
      "Config param pixel_shuffle_ratio_for_vit: 0.5\n",
      "Config param prefill_cache_axis_order: 1,2,0,3\n",
      "Config param prefill_cache_dir: \n",
      "Config param prefill_chunk_size: 256\n",
      "Config param prefill_slice: v5e-16\n",
      "Config param prefix_caching_dram_byte: 100000000000\n",
      "Config param prefix_caching_hbm_byte: 10000000000\n",
      "Config param profile_cleanly: True\n",
      "Config param profile_periodically_period: -1\n",
      "Config param profiler: \n",
      "Config param profiler_steps: 5\n",
      "Config param projector_dropout_for_vit: 0.0\n",
      "Config param projector_input_dim_for_vit: 4096\n",
      "Config param projector_output_dim_for_vit: 4096\n",
      "Config param prometheus_port: 0\n",
      "Config param prompt: I love to\n",
      "Config param q_lora_rank: 0\n",
      "Config param qk_nope_head_dim: 128\n",
      "Config param qk_rope_head_dim: 64\n",
      "Config param qkv_proj: remat\n",
      "Config param quant_cfg_path: \n",
      "Config param quantization: \n",
      "Config param quantization_local_shard_count: 1\n",
      "Config param quantize_kvcache: False\n",
      "Config param query_proj: remat\n",
      "Config param ragged_block_size: 256\n",
      "Config param record_internal_nn_metrics: 0\n",
      "Config param remat_policy: full\n",
      "Config param remat_policy_for_vit: minimal\n",
      "Config param replicate_quant_scale: False\n",
      "Config param replicator_backup_interval_minutes: 0\n",
      "Config param report_heartbeat_metric_for_gcp_monitoring: False\n",
      "Config param report_performance_metric_for_gcp_monitoring: False\n",
      "Config param reshape_q: False\n",
      "Config param return_log_prob: False\n",
      "Config param reuse_example_batch: 0\n",
      "Config param rope_factor: 40\n",
      "Config param rope_max_timescale: 10000\n",
      "Config param rope_min_timescale: 1\n",
      "Config param rope_theta_for_vit: 10000\n",
      "Config param rope_type: default\n",
      "Config param rope_use_scale: True\n",
      "Config param routed_bias: False\n",
      "Config param routed_scaling_factor: 1.0\n",
      "Config param routed_score_func: \n",
      "Config param run_name: test\n",
      "Config param sa_block_kv: 512\n",
      "Config param sa_block_kv_compute: 512\n",
      "Config param sa_block_kv_dkv: 512\n",
      "Config param sa_block_kv_dkv_compute: 512\n",
      "Config param sa_block_kv_dq: 512\n",
      "Config param sa_block_q: 512\n",
      "Config param sa_block_q_dkv: 512\n",
      "Config param sa_block_q_dq: 512\n",
      "Config param sa_k_layout: HEAD_DIM_MINOR\n",
      "Config param sa_q_layout: HEAD_DIM_MINOR\n",
      "Config param sa_use_fused_bwd_kernel: False\n",
      "Config param sa_v_layout: HEAD_DIM_MINOR\n",
      "Config param save_config_to_gcs: False\n",
      "Config param save_quantized_params_path: \n",
      "Config param scan_layers: True\n",
      "Config param scan_layers_per_stage: False\n",
      "Config param scan_pipeline_iterations: True\n",
      "Config param set_remat_policy_on_layers_per_stage: False\n",
      "Config param set_remat_policy_on_pipeline_iterations: True\n",
      "Config param sft_train_on_completion_only: False\n",
      "Config param sharding_tolerance: 0.02\n",
      "Config param shared_experts: 1\n",
      "Config param skip_first_n_steps_for_profiler: 1\n",
      "Config param skip_jax_distributed_system: False\n",
      "Config param sliding_window_size: 0\n",
      "Config param sparse_matmul: True\n",
      "Config param stack_prefill_result_cache: False\n",
      "Config param stack_trace_interval_seconds: 600\n",
      "Config param stack_trace_to_cloud: False\n",
      "Config param step_deviation_interval_seconds: 30\n",
      "Config param steps: 150001\n",
      "Config param target_eval_loss: 0.0\n",
      "Config param temperature_tuning: False\n",
      "Config param tensorboard_dir: test/tensorboard/\n",
      "Config param tile_activation_dim: 1024\n",
      "Config param tile_batch_seq: 512\n",
      "Config param tile_weight_dim: 1024\n",
      "Config param tokenize_eval_data: True\n",
      "Config param tokenize_train_data: True\n",
      "Config param tokenizer_path: assets/tokenizer.llama2\n",
      "Config param tokenizer_type: sentencepiece\n",
      "Config param topk_routing_group: -1\n",
      "Config param train_data_columns: ['text']\n",
      "Config param train_split: train\n",
      "Config param trainable_position_size: -1\n",
      "Config param upload_all_profiler_results: False\n",
      "Config param use_chat_template: False\n",
      "Config param use_chunked_prefill: False\n",
      "Config param use_dpo: False\n",
      "Config param use_iota_embed: False\n",
      "Config param use_multimodal: False\n",
      "Config param use_post_attn_norm: False\n",
      "Config param use_post_ffw_norm: False\n",
      "Config param use_qk_norm: False\n",
      "Config param use_ragged_attention: False\n",
      "Config param use_random_routing: False\n",
      "Config param use_replicator_service: False\n",
      "Config param use_sft: False\n",
      "Config param use_untrainable_positional_embedding: False\n",
      "Config param use_vertex_tensorboard: False\n",
      "Config param using_pipeline_parallelism: False\n",
      "Config param v_head_dim: 128\n",
      "Config param value_proj: remat\n",
      "Config param vertex_tensorboard_project: \n",
      "Config param vertex_tensorboard_region: \n",
      "Config param vision_output_dim_for_vit: 4096\n",
      "Config param vocab_size: 32000\n",
      "Config param warmup_steps_fraction: 0.1\n",
      "Config param weight_dtype: float32\n",
      "Num_devices: 1, shape (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'global_store' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[3], line 18\u001b[0m\n\u001b[1;32m      1\u001b[0m config \u001b[38;5;241m=\u001b[39m pyconfig\u001b[38;5;241m.\u001b[39minitialize(\n\u001b[1;32m      2\u001b[0m     [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdecode.py\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m../configs/base.yml\u001b[39m\u001b[38;5;124m\"\u001b[39m], \u001b[38;5;66;03m#TODO: @mazumdera: why decode.py?\u001b[39;00m\n\u001b[1;32m      3\u001b[0m     per_device_batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1.0\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     14\u001b[0m \n\u001b[1;32m     15\u001b[0m )\n\u001b[1;32m     17\u001b[0m model \u001b[38;5;241m=\u001b[39m mt\u001b[38;5;241m.\u001b[39mfrom_pretrained(config)\n\u001b[0;32m---> 18\u001b[0m mesh, init_rng \u001b[38;5;241m=\u001b[39m \u001b[43mglobal_store\u001b[49m\u001b[38;5;241m.\u001b[39mget_global_mesh_and_init_rng()\n\u001b[1;32m     19\u001b[0m state, _ \u001b[38;5;241m=\u001b[39m maxtext_utils\u001b[38;5;241m.\u001b[39msetup_decode_state(model, config, init_rng, mesh, \u001b[38;5;28;01mNone\u001b[39;00m)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'global_store' is not defined"
     ]
    }
   ],
   "source": [
    "from MaxText.globals import MAXTEXT_PKG_DIR\n",
    "\n",
    "# Small configuration (2 decoder layers, 256-dim embeddings, 4-token sequences, no checkpointing)\n",
    "# layered on top of the defaults in configs/base.yml.\n",
    "config = pyconfig.initialize(\n",
    "    [os.path.join(MAXTEXT_PKG_DIR, \"decode.py\"), os.path.join(MAXTEXT_PKG_DIR, \"configs\", \"base.yml\")],\n",
    "    per_device_batch_size=1.0,\n",
    "    run_name=\"test\",\n",
    "    enable_checkpointing=False,\n",
    "    base_num_decoder_layers=2,\n",
    "    max_target_length=4,\n",
    "    base_emb_dim=256,\n",
    "    base_num_query_heads=2,\n",
    "    base_num_kv_heads=2,\n",
    "    max_prefill_predict_length=4,\n",
    "    # The prompt in this notebook is tokenized with the Llama 3 tiktoken vocabulary (128256 entries,\n",
    "    # BOS id 128000). The base.yml default of vocab_size=32000 would put those token ids outside the\n",
    "    # embedding table, so size the vocabulary to match the tokenizer.\n",
    "    vocab_size=128256,\n",
    "    # tokenizer_path=\"assets/llama3.1-tokenizer/\",\n",
    "    # model_name=\"llama3.1-7b\",\n",
    ")\n",
    "\n",
    "# Build the model from the config, then create the decode state on the model's own device mesh.\n",
    "model = mt.from_config(config)\n",
    "mesh = model.mesh\n",
    "# Seed the PRNG from the config so weight initialization is repeatable across runs.\n",
    "init_rng = jax.random.PRNGKey(config.init_weights_seed)\n",
    "# The final argument (checkpoint manager) is None: checkpointing is disabled above.\n",
    "state, _ = maxtext_utils.setup_decode_state(model, config, init_rng, mesh, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2d2d0c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tokenizer path: /home/mazumdera/maxtext/assets/tokenizer_llama3.tiktoken\n",
      "Reloaded tiktoken model from /home/mazumdera/maxtext/assets/tokenizer_llama3.tiktoken\n",
      "#words: 128256 - BOS ID: 128000 - EOS ID: 128001\n",
      "input_ids=[128000, 40, 3021, 311], ids=[[128000     40   3021    311]], decoder_segment_ids = [[1. 1. 1. 1.]], decoder_positions= [[0 1 2 3]]\n"
     ]
    }
   ],
   "source": [
    "from MaxText.globals import MAXTEXT_ASSETS_ROOT\n",
    "\n",
    "# Load the Llama 3 tiktoken tokenizer; a BOS token is prepended and no EOS token is appended.\n",
    "source_tokenizer = _input_pipeline_utils.get_tokenizer(\n",
    "    os.path.join(MAXTEXT_ASSETS_ROOT, \"tokenizer_llama3.tiktoken\"),\n",
    "    \"tiktoken\",\n",
    "    add_bos=True,\n",
    "    add_eos=False,\n",
    ")\n",
    "\n",
    "# TODO: @mazumdera: any way to get segment and position ids like the HF tokenizer gives us?\n",
    "input_ids = source_tokenizer.encode(config.prompt)\n",
    "\n",
    "# The positions and segment ids below are built for exactly max_target_length tokens and the\n",
    "# prompt is not padded or truncated, so fail early with a clear message if the lengths differ.\n",
    "if len(input_ids) != config.max_target_length:\n",
    "  raise ValueError(\n",
    "      f\"Prompt tokenizes to {len(input_ids)} ids but max_target_length is {config.max_target_length}; \"\n",
    "      \"adjust the prompt or max_target_length so that they match.\"\n",
    "  )\n",
    "\n",
    "batch_size = config.global_batch_size_to_train_on\n",
    "batch_shape = (batch_size, config.max_target_length)\n",
    "\n",
    "# Mark every position as part of the active decoding sequence. float64 is kept on purpose: it is\n",
    "# the dtype the previous `np.zeros(shape) + indicator` expression produced.\n",
    "decoder_segment_ids = np.full(batch_shape, common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR, dtype=np.float64)\n",
    "# Positions 0 .. max_target_length - 1, repeated once per batch row.\n",
    "decoder_positions = np.tile(np.arange(config.max_target_length, dtype=np.int32), (batch_size, 1))\n",
    "# The same prompt ids, repeated once per batch row.\n",
    "ids = np.tile(np.asarray(input_ids, dtype=np.int32), (batch_size, 1))\n",
    "\n",
    "max_logging.log(\n",
    "    f\"input_ids={input_ids}, ids={ids}, decoder_segment_ids = {decoder_segment_ids}, decoder_positions= {decoder_positions}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5a1fe11",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[CpuDevice(id=0)]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import importlib.util\n",
    "import os\n",
    "\n",
    "import jax\n",
    "\n",
    "# `!export VAR=...` only changes a short-lived subshell and never reaches this kernel's environment,\n",
    "# so the variable is set in-process instead. The library is located through the installed `libtpu`\n",
    "# package rather than a user-specific absolute path, and an existing value is left untouched.\n",
    "# NOTE(review): JAX presumably reads TPU_LIBRARY_PATH only when it first initializes its backends;\n",
    "# if JAX has already run in this kernel, restart the kernel and run this cell first - confirm.\n",
    "libtpu_spec = importlib.util.find_spec(\"libtpu\")\n",
    "if libtpu_spec is not None and libtpu_spec.origin:\n",
    "  libtpu_so = os.path.join(os.path.dirname(libtpu_spec.origin), \"libtpu.so\")\n",
    "  if os.path.exists(libtpu_so):\n",
    "    os.environ.setdefault(\"TPU_LIBRARY_PATH\", libtpu_so)\n",
    "\n",
    "jax.devices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d42b156",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/mazumdera/.local/lib/python3.10/site-packages/libtpu/libtpu.so\n"
     ]
    }
   ],
   "source": [
    "import importlib.util\n",
    "import os\n",
    "\n",
    "# Report where the libtpu shared library is installed, looking it up through the `libtpu` package\n",
    "# instead of listing a hard-coded, user-specific absolute path.\n",
    "libtpu_spec = importlib.util.find_spec(\"libtpu\")\n",
    "libtpu_so = None\n",
    "if libtpu_spec is not None and libtpu_spec.origin:\n",
    "  libtpu_so = os.path.join(os.path.dirname(libtpu_spec.origin), \"libtpu.so\")\n",
    "print(libtpu_so if libtpu_so and os.path.exists(libtpu_so) else \"libtpu.so not found\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7436751b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "full_train_logits[0, 0, :]=array([[ 0.6484375 , -1.09375   , -1.3359375 , ...,  0.0177002 ,\n",
      "        -0.8984375 , -0.57421875],\n",
      "       [ 0.8125    , -0.53125   , -0.3125    , ...,  1.34375   ,\n",
      "         1.078125  , -1.3828125 ],\n",
      "       [ 0.6171875 , -2.        , -2.0625    , ...,  0.13867188,\n",
      "        -0.9375    , -0.796875  ],\n",
      "       [-0.27734375, -1.3203125 , -0.765625  , ...,  1.1171875 ,\n",
      "        -0.26953125,  0.4296875 ]], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "import jax.experimental.multihost_utils\n",
    "\n",
    "# Run the model once over the prompt with dropout disabled; `init_rng` feeds the \"aqt\" RNG stream.\n",
    "forward_kwargs = dict(enable_dropout=False, rngs={\"aqt\": init_rng})\n",
    "local_logits = model.apply(state.params, ids, decoder_positions, decoder_segment_ids, **forward_kwargs)\n",
    "\n",
    "# Gather the logits across JAX processes, then log a slice of the result for inspection.\n",
    "full_train_logits = jax.experimental.multihost_utils.process_allgather(local_logits)\n",
    "max_logging.log(f\"{full_train_logits[0, 0, :]=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb06c0c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take a (1, 1, 1, N) window from the 4-D logits: index 0 on the first two axes, the final index\n",
    "# on axis 2, and the whole of axis 3.\n",
    "# NOTE(review): axis 2 is presumably the sequence axis and axis 3 the vocabulary axis - confirm.\n",
    "last_index_axis2 = full_train_logits.shape[2] - 1\n",
    "size_axis3 = full_train_logits.shape[3]\n",
    "selected_logits = jax.lax.dynamic_slice(\n",
    "    full_train_logits,\n",
    "    start_indices=(0, 0, last_index_axis2, 0),\n",
    "    slice_sizes=(1, 1, 1, size_axis3),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "308f2a57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive the sampling key without rebinding `init_rng`: re-running this cell then always uses the\n",
    "# same key, and other cells that read `init_rng` are unaffected. On a fresh run `new_rng` is the\n",
    "# same key the previous `init_rng, new_rng = jax.random.split(init_rng)` produced.\n",
    "new_rng = jax.random.split(init_rng)[1]\n",
    "\n",
    "# Pick the next token from the selected logits using the sampling settings in the config.\n",
    "first_generated_token = inference_utils.sampling(\n",
    "    selected_logits,\n",
    "    new_rng,\n",
    "    config.decode_sampling_strategy,\n",
    "    topk=config.decode_sampling_top_k,\n",
    "    nucleus_topp=config.decode_sampling_nucleus_p,\n",
    "    temperature=config.decode_sampling_temperature,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32555a83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "26831"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the sampled token id as a plain Python scalar rather than an array.\n",
    "first_generated_token.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3de52746",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'-ad'"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Turn the sampled token id back into text with the tokenizer that encoded the prompt.\n",
    "generated_token_id = first_generated_token.item()\n",
    "source_tokenizer.decode([generated_token_id])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
